import json
import logging

import requests
from webserver.util import query_util

from database import (
    fix_ids,
    ImageModel,
    CategoryModel,
    AnnotationModel,
    DatasetModel,
    TaskModel,
    ExportModel
)

from celery import shared_task
from ..socket import create_socket
from mongoengine import Q

# from workers import celery

# Module logger. It attaches to the logger named 'gunicorn.error'
# (presumably so task output shares gunicorn's handlers -- confirm).
logger = logging.getLogger('gunicorn.error')


# Inference service endpoint used when the caller does not supply one.
_DEFAULT_INFERENCE_URL = 'http://125.220.157.228:29994/'


def _load_category_names(category_ids):
    """Map each requested category id to its name, skipping missing/deleted ones.

    Args:
        category_ids: iterable of category ids as supplied by the caller.

    Returns:
        dict mapping each found category id (as given) to the category's name.
    """
    category_dict = {}
    for category_id in category_ids:
        # ``objects(...)`` is a QuerySet; ``.first()`` yields the document or None.
        category = CategoryModel.objects(id=category_id, deleted=False).first()
        if category is None:
            logger.warning("Category %s not found or deleted; skipping", category_id)
            continue
        category_dict[category_id] = category.name
    return category_dict


def _request_annotations(url, image_path, category_dict, box_treshold,
                         text_treshold, timeout):
    """POST one image to the inference service and return its annotation list.

    Raises:
        requests.RequestException: on connection failure, timeout or an HTTP
            error status from the service.
        KeyError: if the JSON response has no 'annotations' key.
    """
    data = {
        # requests flattens an iterable form value into its elements, so a
        # raw dict would be sent as its keys only (names lost). Send the
        # id -> name mapping as a JSON string instead.
        # NOTE(review): the inference service must json-decode this field.
        'class_names_dict': json.dumps(category_dict),
        'box_treshold': box_treshold,
        'text_treshold': text_treshold,
    }
    # First version: one request per image.
    with open(image_path, 'rb') as image:
        r = requests.post(url, data=data, files={'image': image}, timeout=timeout)
    r.raise_for_status()
    # TODO: change to a two-dimensional array format.
    return r.json()['annotations']


def _save_annotations(image_model, annotations, metadata):
    """Create and save an AnnotationModel per returned annotation.

    Annotations that fail construction/saving with ValueError or TypeError
    are logged and skipped.

    Returns:
        int: number of annotations saved.
    """
    saved = 0
    for annotation in annotations:
        category_id = annotation['category_id']
        segmentation = annotation['segmentation']
        area = annotation['area']
        bbox = annotation['bbox']
        try:
            AnnotationModel(
                image_id=image_model.id,
                category_id=category_id,
                metadata=metadata,
                segmentation=segmentation,
                bbox=bbox,
                area=area,
            ).save()
        except (ValueError, TypeError) as e:
            logger.warning(
                "Failed to save annotation for image %s: %s", image_model.id, e
            )
            continue
        saved += 1
    return saved


@shared_task
def batch(dataset_id, category_ids, metadata, box_treshold, text_treshold,
          url=_DEFAULT_INFERENCE_URL, timeout=None):
    """Auto-annotate every image of a dataset via the remote inference service.

    For each image of ``dataset_id`` the image file is POSTed to ``url``
    together with the selected category names and thresholds, and each
    returned annotation is stored as an ``AnnotationModel``.

    Args:
        dataset_id: id of the dataset whose images are annotated.
        category_ids: ids of the categories to detect. Missing or deleted
            categories are skipped.
        metadata: metadata attached to every created annotation.
        box_treshold: forwarded unchanged to the inference service.
        text_treshold: forwarded unchanged to the inference service.
        url: inference service endpoint.
        timeout: optional requests timeout (seconds). None waits indefinitely,
            which matches the previous behaviour.

    Raises:
        requests.RequestException: if a request fails or returns an HTTP error.
    """
    # Kept for any side effects of create_socket(); the socket itself is not
    # used by this task.
    create_socket()

    # Combine category ids and names into a dict sent to the algorithm.
    category_dict = _load_category_names(category_ids)
    if not category_dict:
        logger.warning(
            "No valid categories for dataset %s; nothing to annotate", dataset_id
        )
        return

    image_models = ImageModel.objects(dataset_id=dataset_id)
    logger.info(
        "Batch annotating dataset %s with %d categories",
        dataset_id, len(category_dict),
    )

    total_saved = 0
    for image_model in image_models:
        annotations = _request_annotations(
            url, image_model.path, category_dict,
            box_treshold, text_treshold, timeout,
        )
        total_saved += _save_annotations(image_model, annotations, metadata)

    logger.info(
        "Batch annotation of dataset %s finished: %d annotations saved",
        dataset_id, total_saved,
    )


# Public API of this module: only the `batch` Celery task is exported.
__all__ = ["batch"]
